#!/usr/bin/env python3

# Regblock generation script

import argparse
import re
import sys
import textwrap
import yaml
from collections import OrderedDict
from math import log2, ceil

# Regblock data structures and parser/loader

class RegBlock:
	"""A register block: bus parameters plus an ordered list of registers.

	The constructor stores name, w_data, w_addr, bus and info as attributes
	and starts with an empty ``regs`` list, which add() appends to.
	"""
	# Bus interfaces the constructor accepts.
	bus_types = ["apb"]
	# Keys permitted in a field ("bits" entry) of the YAML description.
	field_properties = ["name", "b", "access", "info", "rst", "concat"]

	def __init__(self, name, w_data, w_addr, bus, info):
		"""Create an empty register block.

		Raises Exception if ``bus`` is not one of ``bus_types``.
		"""
		self.name = name
		self.w_data = w_data
		self.w_addr = w_addr
		if bus not in self.bus_types:
			raise Exception("Unknown bus type: {}".format(bus))
		self.bus = bus
		self.info = info
		self.regs = []

	def add(self, reg):
		"""Append a register to the block, preserving insertion order."""
		self.regs.append(reg)

	@staticmethod
	def load(file):
		"""Build a RegBlock from a YAML description read from ``file``.

		``file`` is an object with a read() method returning the YAML text.
		The top level must be a mapping with the keys "bus", "data", "addr",
		"name" and "regs"; "info" and "params" are optional. The process
		exits with a message (sys.exit) if a required key is missing.

		Raises Exception for an unknown field property, and propagates any
		exception raised by the Register and Field constructors or by
		Register.add.
		"""
		# SECURITY: the description is executable input, not plain data.
		# yaml.FullLoader is not the safe loader, and the "generate" and
		# bit-range expressions below are run through exec()/eval().
		# Only load description files from a trusted source.
		y = yaml.load(file.read(), Loader=yaml.FullLoader)
		if not isinstance(y, dict):
			sys.exit("Regblock description must be a YAML mapping")
		required = [
			("bus", "Must specify bus type with \"bus: x\""),
			("data", "Must specify data width with \"data: x\""),
			("addr", "Must specify address width with \"addr: x\""),
			("name", "Must specify regblock name with \"name: x\""),
			("regs", "Must specify register list with \"regs: ...\""),
		]
		for key, message in required:
			if key not in y:
				sys.exit(message)
		rb = RegBlock(y["name"], y["data"], y["addr"], y["bus"], y.get("info"))
		# Names available to bit-range expressions and generate scripts.
		param_dict = {"W_DATA": y["data"]}
		if "params" in y:
			param_dict = {**param_dict, **y["params"]}

		def param_resolve(x):
			# Plain integers pass through; anything else is evaluated as a
			# Python expression over the parameters (see SECURITY note).
			if type(x) is int:
				return x
			return int(eval(x, {**param_dict}))

		# Expand any generate blocks (one level only)
		reglist = []
		for rspec in y["regs"]:
			if "generate" not in rspec:
				reglist.append(rspec)
				continue
			# The script calls _(text) to emit lines of YAML, which are then
			# parsed as a list of additional register specs.
			yaml_lines = []
			emit = lambda x: yaml_lines.append(str(x))
			exec(rspec["generate"], {"_": emit, **param_dict}, dict())
			newregs = yaml.load("\n".join(yaml_lines), Loader=yaml.FullLoader)
			if newregs is not None:
				reglist.extend(newregs)
		# Then process the expanded reglist
		for rspec in reglist:
			reg = Register(rspec["name"], rb.w_data, rspec.get("info"), rspec.get("wstb"))
			rb.add(reg)
			for fspec in rspec["bits"]:
				for key in fspec:
					if key not in RegBlock.field_properties:
						raise Exception("'{}' is not a valid property for a field.".format(key))
				# "b" is either a single bit index or an [msb, lsb] pair.
				bitrange = fspec["b"]
				if type(bitrange) is int:
					bitrange = [bitrange, bitrange]
				reg.add(Field(
					fspec.get("name", ""),
					param_resolve(bitrange[0]),
					param_resolve(bitrange[1]),
					fspec["access"],
					fspec.get("rst", 0),
					fspec.get("info"),
					fspec.get("concat")
				))
		return rb

	def accept(self, visitor):
		"""Call visitor.pre(self), visit each register in order, then visitor.post(self)."""
		visitor.pre(self)
		for reg in self.regs:
			reg.accept(visitor)
		visitor.post(self)

class Register:
	"""One register: a fixed-width word holding named, non-overlapping fields."""

	def __init__(self, name, width, info, wstrobe=None):
		"""Create an empty register of ``width`` bits (must be positive)."""
		assert(width > 0)
		self.name = name
		self.width = width
		self.info = info
		self.wstrobe = wstrobe
		# Insertion-ordered map of field name -> field object.
		self.fields = OrderedDict()
		# occupancy[i] is the name of the field owning bit i, or None if free.
		self.occupancy = [None] * width

	def add(self, field):
		"""Place ``field`` in this register and set its parent to this register.

		Raises Exception if the field lies outside the register, reuses an
		existing field name, or overlaps bits that are already occupied.
		"""
		lsb_ok = 0 <= field.lsb < self.width
		msb_ok = 0 <= field.msb < self.width
		if not (lsb_ok and msb_ok):
			raise Exception("Field {} extends outside of register {}".format(field.name, self.name))
		if field.name in self.fields:
			raise Exception("{} already has a field called \"{}\"".format(self.name, field.name))
		# Claim bits one at a time, stopping at the first bit already taken.
		bit = field.lsb
		while bit <= field.msb:
			owner = self.occupancy[bit]
			if owner is not None:
				raise Exception("Field {} overlaps {} in register {}".format(field.name, owner, self.name))
			self.occupancy[bit] = field.name
			bit += 1
		field.parent = self
		self.fields[field.name] = field

	def accept(self, visitor):
		"""Call visitor.pre(self), visit each field in insertion order, then visitor.post(self)."""
		visitor.pre(self)
		for member in self.fields.values():
			member.accept(visitor)
		visitor.post(self)


class Field:
	"""A contiguous run of bits [msb:lsb] inside a register, with an access type."""
	access_types = ["ro", "rov", "wo", "rw", "rf", "wf", "rwf", "sc", "w1c"]

	def __init__(self, name, msb, lsb, access, resetval=0, info="", concat=None, parent=None):
		"""Validate and store the field description.

		Raises Exception if lsb is above msb, if ``access`` is not one of
		``access_types``, or if an "sc"/"w1c" field is wider than one bit.
		"""
		if msb < lsb:
			raise Exception("Field width must be >= 0 in field {}".format(name))
		if access not in self.access_types:
			known = ", ".join(self.access_types)
			raise Exception("Unknown access type: {}. Recognised types: {}".format(access, known))
		single_bit_only = access in ["sc", "w1c"]
		if single_bit_only and msb != lsb:
			raise Exception("Field width must be 1 for access type '{}'".format(access))
		self.name = name
		self.msb, self.lsb = msb, lsb
		self.access = access
		self.resetval = resetval
		self.info = info
		self.concat = concat
		self.parent = parent

	@property
	def width_decl(self):
		"""Verilog range for the field: "" for a single bit, else "[N:0]"."""
		span = self.msb - self.lsb
		return "[{}:0]".format(span) if span != 0 else ""

	@property
	def width(self):
		"""Number of bits in the field."""
		return 1 + self.msb - self.lsb

	@property
	def fullname(self):
		"""Parent register name, suffixed with "_<field name>" unless the field is unnamed."""
		owner = self.parent.name
		if self.name != "":
			return "{}_{}".format(owner, self.name)
		return owner

	def accept(self, visitor):
		"""Fields are leaves: only visitor.pre(self) is called."""
		visitor.pre(self)

# Base class for various useful lumps of functionality
# which are applied as traversals on a regblock description tree
class RegBlockVisitor:
	"""Abstract visitor over a RegBlock / Register / Field tree.

	pre() and post() dispatch on the exact type of the visited node to the
	matching pre_*/post_* hook. Every hook raises NotImplementedError here,
	so subclasses override the ones that get called.
	"""

	def pre(self, x):
		"""Route ``x`` to its pre_* hook; raise TypeError for any other type."""
		kind = type(x)
		if kind is RegBlock:
			hook = self.pre_regblock
		elif kind is Register:
			hook = self.pre_register
		elif kind is Field:
			hook = self.pre_field
		else:
			raise TypeError()
		hook(x)

	def post(self, x):
		"""Route ``x`` to its post_* hook; raise TypeError for any other type."""
		kind = type(x)
		if kind is RegBlock:
			hook = self.post_regblock
		elif kind is Register:
			hook = self.post_register
		elif kind is Field:
			hook = self.post_field
		else:
			raise TypeError()
		hook(x)

	# Hooks run on entering a node.

	def pre_regblock(self, rb):
		raise NotImplementedError()

	def pre_register(self, r):
		raise NotImplementedError()

	def pre_field(self, f):
		raise NotImplementedError()

	# Hooks run on leaving a node.

	def post_regblock(self, rb):
		raise NotImplementedError()

	def post_register(self, r):
		raise NotImplementedError()

	def post_field(self, f):
		raise NotImplementedError()


# Verilog generation

def width2str(w):
	"""Return the Verilog range "[w-1:0]" for a vector of ``w`` bits."""
	top_bit = w - 1
	return "[{}:0]".format(top_bit)

def c_block_comment(lines, width=80, align="<"):
	"""Wrap ``lines`` in a boxed C block comment and return it as a list of strings.

	Each input line is stripped and padded to ``width - 4`` characters using
	the format alignment character ``align`` ("<", "^" or ">"). The top and
	bottom borders are ``width`` characters long.
	"""
	stars = "*" * (width - 1)
	row_fmt = "* {:" + align + str(width - 4) + "} *"
	rows = [row_fmt.format(text.strip()) for text in lines]
	return ["/" + stars] + rows + [stars + "/"]

def c_line_comment(line):
	"""Return ``line`` as "// " comment text, wrapped at 80 columns.

	The result is a single string with wrapped lines joined by newlines.
	The TextWrapper is created on first use and cached as an attribute of
	this function.
	"""
	try:
		wrapper = c_line_comment.wrap
	except AttributeError:
		wrapper = c_line_comment.wrap = textwrap.TextWrapper(width=80, initial_indent="// ", subsequent_indent="// ")
	wrapped = wrapper.wrap(line.strip())
	return "\n".join(wrapped)

class Verilog:
	"""Accumulates the pieces of one Verilog module as lists of source lines.

	header     -- lines up to and including the "module name (" line
	ports      -- port declarations without trailing commas; "" and "//..."
	              entries are treated as spacing/comments within the port list
	decls      -- module-level declarations and continuous assignments
	logic_comb -- statements placed inside a single always @ (*) block
	logic_rst  -- statements for the reset branch of the clocked block
	logic_clk  -- statements for the non-reset branch of the clocked block
	"""

	def __init__(self):
		self.header = []
		self.ports = []
		self.decls = []
		self.logic_comb = []
		self.logic_rst = []
		self.logic_clk = []

	def __add__(self, other):
		"""Return a new Verilog with each section of self followed by other's.

		Only self's header is kept; other's header is not copied.
		"""
		new = Verilog()
		new.header.extend(self.header)
		for n, s, o in zip(
			[new.ports, new.decls, new.logic_comb, new.logic_rst, new.logic_clk],
			[self.ports, self.decls, self.logic_comb, self.logic_rst, self.logic_clk],
			[other.ports, other.decls, other.logic_comb, other.logic_rst, other.logic_clk]):
			n.extend(s)
			n.extend(o)
		return new

	@staticmethod
	def _is_port_decl(entry):
		# Blank and comment entries in the port list never take a comma.
		return entry != "" and not entry.startswith("//")

	def __str__(self):
		"""Render the complete module text, ending with "endmodule" and a newline."""
		strs = []
		strs.extend(self.header)
		# Every real port except the last one is followed by a comma. The
		# last real port is located explicitly because blank/comment entries
		# may trail it, and a comma before ");" is not valid Verilog.
		last_decl = max(
			(i for i, p in enumerate(self.ports) if self._is_port_decl(p)),
			default=-1)
		for i, p in enumerate(self.ports):
			comma = "," if self._is_port_decl(p) and i != last_decl else ""
			strs.append("\t" + p + comma)
		strs.append(");")
		strs.append("")
		strs.extend(self.decls)
		strs.append("")
		strs.append("always @ (*) begin")
		strs.extend("\t" + s for s in self.logic_comb)
		strs.append("end")
		strs.append("")
		strs.append("always @ (posedge clk or negedge rst_n) begin")
		strs.append("\tif (!rst_n) begin")
		strs.extend("\t\t" + s for s in self.logic_rst)
		strs.append("\tend else begin")
		strs.extend("\t\t" + s for s in self.logic_clk)
		strs.append("\tend")
		strs.append("end")
		strs.append("")
		strs.append("endmodule\n")
		return "\n".join(strs)


class VerilogWriter(RegBlockVisitor):
	"""Visitor that builds the Verilog register-file module for a regblock.

	Output accumulates in self.v (a Verilog object); str(self.v) gives the
	module text once the whole tree has been visited. The order in which
	ports, declarations and statements are appended is the order in which
	they appear in the generated module.
	"""
	def __init__(self):
		self.v = Verilog()
		# Name of the register currently being visited (set by pre_register).
		self.regname = None
		# Concat bus name -> list of (width, output signal name), in visit order.
		self.concat = OrderedDict()
		# Write strobe name -> list of names of registers that trigger it.
		self.wstrobe = OrderedDict()

	def pre_regblock(self, rb):
		"""Emit the file banner, module header, bus adapter and address decode."""
		v = self.v
		v.header.extend(c_block_comment(["AUTOGENERATED BY REGBLOCK", "Do not edit manually.",
			"Edit the source file (or regblock utility) and regenerate."], align = "^"))
		v.header.append("")
		v.header.append(c_line_comment("{:<20} : {}".format("Block name", rb.name)))
		v.header.append(c_line_comment("{:<20} : {}".format("Bus type", rb.bus)))
		v.header.append(c_line_comment("{:<20} : {}".format("Bus data width", rb.w_data)))
		v.header.append(c_line_comment("{:<20} : {}".format("Bus address width", rb.w_addr)))
		v.header.append("")
		if rb.info is not None:
			v.header.append(c_line_comment(rb.info))
			v.header.append("")
		v.header.append("module {}_regs (".format(rb.name))
		v.ports.extend(["input wire clk", "input wire rst_n",])
		# "+" binds tighter than "<<", so this is (1 << (2 + ceil(log2(n)))) - 1:
		# a mask of the low address bits needed for n registers at a 4-byte stride.
		# NOTE(review): log2() raises ValueError when rb.regs is empty.
		addr_mask = (1 << 2 + ceil(log2(len(rb.regs)))) - 1
		# Clear the low ceil(log2(w_data / 8)) bits of the mask.
		addr_mask = addr_mask & (-1 << ceil(log2(rb.w_data / 8)))
		if rb.bus == "apb":
			v.ports.append("")
			v.ports.append("// APB Port")
			v.ports.append("input wire apbs_psel")
			v.ports.append("input wire apbs_penable")
			v.ports.append("input wire apbs_pwrite")
			v.ports.append("input wire {} apbs_paddr".format(width2str(rb.w_addr)))
			v.ports.append("input wire {} apbs_pwdata".format(width2str(rb.w_data)))
			v.ports.append("output wire {} apbs_prdata".format(width2str(rb.w_data)))
			v.ports.append("output wire apbs_pready")
			v.ports.append("output wire apbs_pslverr")
			# Adapt APB signals to the internal wdata/rdata/wen/ren/addr names
			# used by the rest of the generated logic. pready is tied high and
			# pslverr is tied low.
			v.decls.append("// APB adapter")
			v.decls.append("wire {} wdata = apbs_pwdata;".format(width2str(rb.w_data)))
			v.decls.append("reg {} rdata;".format(width2str(rb.w_data)))
			v.decls.append("wire wen = apbs_psel && apbs_penable && apbs_pwrite;")
			v.decls.append("wire ren = apbs_psel && apbs_penable && !apbs_pwrite;")
			v.decls.append("wire {} addr = apbs_paddr & {}'h{:x};".format(width2str(rb.w_addr), rb.w_addr, addr_mask))
			v.decls.append("assign apbs_prdata = rdata;")
			v.decls.append("assign apbs_pready = 1'b1;")
			v.decls.append("assign apbs_pslverr = 1'b0;")
			v.decls.append("")
		v.ports.append("")
		v.ports.append("// Register interfaces")
		# Registers are placed at consecutive multiples of 4, in declaration order.
		# NOTE(review): the stride is 4 regardless of rb.w_data -- confirm intended.
		for i, reg in enumerate(rb.regs):
			v.decls.append("localparam ADDR_{} = {};".format(reg.name.upper(), i * 4))
		v.decls.append("")
		# Per-register write/read enables, decoded from the masked address.
		for reg in rb.regs:
			v.decls.append("wire __{}_wen = wen && addr == ADDR_{};".format(reg.name, reg.name.upper()))
			v.decls.append("wire __{}_ren = ren && addr == ADDR_{};".format(reg.name, reg.name.upper()))
		# Read mux: select the addressed register's read data, else zero.
		v.logic_comb.append("case (addr)")
		for reg in rb.regs:
			v.logic_comb.append("\tADDR_{}: rdata = __{}_rdata;".format(reg.name.upper(), reg.name))
		v.logic_comb.append("\tdefault: rdata = {}'h0;".format(rb.w_data))
		v.logic_comb.append("endcase")

	def post_regblock(self, rb):
		"""Emit the concat output buses and write strobe registers collected during the visit."""
		for name, conns in self.concat.items():
			concat_name = "concat_{}_o".format(name)
			width = sum(f[0] for f in conns)
			self.v.ports.append("output wire [{}:0] {}".format(width - 1, concat_name))
			# Signals are listed in reverse visit order, so the first field
			# visited occupies the least significant bits of the bus.
			self.v.decls.append("assign {} = {{{}}};".format(concat_name, ", ".join(f[1] for f in reversed(conns))))
		# Strobe names of the form "base[idx]" share one vector port per base;
		# this records the highest index seen for each base name.
		max_strobe_index = OrderedDict()
		for name, conns in self.wstrobe.items():
			# Each strobe is a register, reset to 0, that takes the OR of the
			# write enables of all registers naming it.
			self.v.logic_rst.append("wstrobe_{} <= 1'b0;".format(name))
			self.v.logic_clk.append("wstrobe_{} <= {};".format(name, " || ".join(
				"__{}_wen".format(conn) for conn in conns)))
			if name.endswith("]"):
				shortname = name.split("[")[0]
				idx = int(name.split("[")[-1][:-1])
				if shortname in max_strobe_index:
					max_strobe_index[shortname] = max(max_strobe_index[shortname], idx)
				else:
					max_strobe_index[shortname] = idx
			else:
				self.v.ports.append("output reg wstrobe_{}".format(name))
		for name, max_idx in max_strobe_index.items():
			self.v.ports.append("output reg [{}:0] wstrobe_{}".format(max_idx, name))

	def pre_register(self, reg):
		"""Declare per-field wdata/rdata wires and the register's assembled read data."""
		v = self.v
		self.regname = reg.name
		v.decls.append("")
		# Walk the bits from MSB to LSB to build the read-data concatenation:
		# each field contributes its rdata once, and each run of unoccupied
		# bits becomes a zero constant of the same length.
		rdata_conns = []
		empty_count = 0
		last_occupant = None
		for occupant in reversed(reg.occupancy):
			if occupant is None:
				empty_count += 1
			elif occupant != last_occupant:
				if empty_count > 0:
					rdata_conns.append("{}'h0".format(empty_count))
					empty_count = 0
				rdata_conns.append("{}_rdata".format(reg.fields[occupant].fullname))
			last_occupant = occupant
		if empty_count > 0:
			rdata_conns.append("{}'h0".format(empty_count))
		for field in reg.fields.values():
			# Recover the field's bit range from the occupancy map.
			lsb = reg.occupancy.index(field.name)
			msb = lsb - 1 + reg.occupancy.count(field.name)
			index = "[{}]".format(lsb) if msb == lsb else "[{}:{}]".format(msb, lsb)
			v.decls.append("wire {} {}_wdata = wdata{};".format(field.width_decl, field.fullname, index))
			v.decls.append("wire {} {}_rdata;".format(field.width_decl, field.fullname))
		v.decls.append("wire {} __{}_rdata = {{{}}};".format(width2str(reg.width), reg.name, ", ".join(rdata_conns)))

		# Record this register against its write strobe, if it names one.
		if reg.wstrobe is not None:
			if not reg.wstrobe in self.wstrobe:
				self.wstrobe[reg.wstrobe] = []
			self.wstrobe[reg.wstrobe].append(reg.name)

	def post_register(self, reg):
		"""Nothing to emit when leaving a register."""
		pass

	def pre_field(self, f):
		"""Emit ports and logic for one field according to its access type.

		Raises Exception if the field specifies concat but its access type is
		not one of wo, rw, wf, rwf, sc.
		"""
		v = self.v
		rname = self.regname
		fname = f.fullname
		# Input port <field>_i.
		if f.access in ["rov", "rf", "rwf", "w1c"]:
			v.ports.append("input wire {} {}_i".format(f.width_decl, fname))
		# Reads return the input port value directly.
		if f.access in ["rov", "rf", "rwf"]:
			v.decls.append("assign {}_rdata = {}_i;".format(fname, fname))
		# Output port <field>_o.
		if f.access in ["wo", "rw", "wf", "rwf", "sc", "w1c"]:
			v.ports.append("output reg {} {}_o".format(f.width_decl, fname))
		# Combinational write pass-through: expose the register's write
		# enable and the field's write data on the output ports.
		if f.access in ["wf", "rwf"]:
			v.ports.append("output reg {}_wen".format(fname))
			v.logic_comb.append("{}_wen = __{}_wen;".format(fname, rname))
			v.logic_comb.append("{}_o = {}_wdata;".format(fname, fname))
		# Expose the register's read enable.
		if f.access in ["rf", "rwf"]:
			v.ports.append("output reg {}_ren".format(fname))
			v.logic_comb.append("{}_ren = __{}_ren;".format(fname, rname))
		# Reads return the stored output value.
		if f.access in ["rw"]:
			v.decls.append("assign {}_rdata = {}_o;".format(fname, fname))
		# Stored value: reset to resetval, updated when the register is written.
		if f.access in ["rw", "wo"]:
			v.logic_rst.append("{}_o <= {}'h{:x};".format(fname, f.width, f.resetval))
			v.logic_clk.append("if (__{}_wen)".format(rname))
			v.logic_clk.append("\t{}_o <= {}_wdata;".format(fname, fname))
		# Write-1-to-clear flag held in an internal reg named after the field:
		# a write with the data bit high clears it, the input sets it (the
		# input is OR-ed in last, so a set wins over a clear), and <field>_o
		# mirrors the flag.
		if f.access in ["w1c"]:
			v.logic_rst.append("{} <= {}'h{:x};".format(fname, f.width, f.resetval))
			v.decls.append("reg {} {};".format(f.width_decl, fname))
			v.decls.append("assign {}_rdata = {};".format(fname, fname))
			v.logic_clk.append("{0} <= ({0} && !(__{1}_wen && {0}_wdata)) || {0}_i;".format(fname, rname))
			v.logic_comb.append("{0}_o = {0};".format(fname))
		# Output is the write data gated by the register's write enable,
		# with no storage.
		if f.access in ["sc"]:
			v.logic_comb.append("{0}_o = {0}_wdata & {{{1}{{__{2}_wen}}}};".format(fname, f.width, rname))
		# Reads return the constant resetval.
		if f.access in ["ro", "wo", "wf", "sc"]:
			v.decls.append("assign {}_rdata = {}'h{:x};".format(fname, f.width, f.resetval))

		# Collect outputs that are to be bundled into a named concat bus.
		# NOTE(review): w1c fields get a <field>_o port above but are rejected
		# here -- confirm whether the exclusion is intentional.
		if f.concat is not None:
			if not f.access in ["wo", "rw", "wf", "rwf", "sc"]:
				raise Exception("concat specified for port with no regblock output")
			if not f.concat in self.concat:
				self.concat[f.concat] = []
			self.concat[f.concat].append((f.width, "{}_o".format(fname)))


# C header generation

class HeaderWriter(RegBlockVisitor):
	"""Visitor that renders a regblock as a C header of #define constants.

	Lines accumulate in self.lines; str(self) joins them, one per line,
	each terminated by a newline.
	"""
	def __init__(self):
		self.lines = []
		# Names of the block and register currently being visited.
		self.blockname = None
		self.regname = None

	def __str__(self):
		return "".join(l + "\n" for l in self.lines)

	def pre_regblock(self, rb):
		"""Emit the banner, include guard, block summary and register offsets."""
		lines = self.lines
		self.blockname = rb.name
		lines.extend(c_block_comment(["AUTOGENERATED BY REGBLOCK", "Do not edit manually.",
			"Edit the source file (or regblock utility) and regenerate."], align = "^"))
		lines.append("")
		lines.append("#ifndef _{}_REGS_H_".format(rb.name.upper()))
		lines.append("#define _{}_REGS_H_".format(rb.name.upper()))
		lines.append("")
		lines.append(c_line_comment("{:<20} : {}".format("Block name", rb.name)))
		lines.append(c_line_comment("{:<20} : {}".format("Bus type", rb.bus)))
		lines.append(c_line_comment("{:<20} : {}".format("Bus data width", rb.w_data)))
		lines.append(c_line_comment("{:<20} : {}".format("Bus address width", rb.w_addr)))
		lines.append("")
		if rb.info is not None:
			lines.append(c_line_comment(rb.info))
			lines.append("")
		# Offsets are consecutive multiples of 4, in declaration order.
		for i, reg in enumerate(rb.regs):
			lines.append("#define {}_{}_OFFS {}".format(rb.name.upper(), reg.name.upper(), i * 4))

	def post_regblock(self, rb):
		"""Close the include guard."""
		self.lines.append("")
		self.lines.append("#endif // _{}_REGS_H_".format(rb.name.upper()))

	def pre_register(self, reg):
		"""Emit a boxed heading for the register, followed by its info text if any."""
		lines = self.lines
		lines.append("")
		lines.extend(c_block_comment([reg.name.upper()], align = "^"))
		lines.append("")
		if reg.info is not None:
			lines.append(c_line_comment(reg.info))
			lines.append("")
		self.regname = reg.name

	def post_register(self, reg):
		"""Nothing to emit when leaving a register."""
		pass

	def pre_field(self, f):
		"""Emit the _LSB, _BITS and _MASK defines for one field."""
		fname = f.fullname.upper()
		lines = self.lines
		lines.append(c_line_comment("Field: {}  Access: {}".format(fname, f.access.upper())))
		# The defines are prefixed with the block name; the comment above is not.
		fname = self.blockname.upper() + "_" + fname
		if f.info is not None:
			lines.append(c_line_comment(f.info))
		lines.append("#define {}_LSB  {}".format(fname, f.lsb))
		lines.append("#define {}_BITS {}".format(fname, f.width))
		# f.width ones shifted up to the field's LSB. Computed directly so
		# that fields above bit 31 (data widths over 32) are not truncated.
		mask = ((1 << f.width) - 1) << f.lsb
		lines.append("#define {}_MASK {:#x}".format(fname, mask))

def change_ext(fname, ext):
	"""Return ``fname`` with its extension replaced by ``ext``.

	``ext`` may be given with or without a leading dot. Only the final
	extension of the last path component is replaced; a name that has no
	extension simply gets ``ext`` appended.
	"""
	import os.path
	# splitext only looks at the last path component, so dots in directory
	# names (e.g. "./build/regs") and extension-less names are handled.
	stem = os.path.splitext(fname)[0]
	return stem + "." + ext.strip(".")

if __name__ == "__main__":
	# Command-line entry point: load a regblock description and write the
	# requested outputs (Verilog register file and/or C header).
	parser = argparse.ArgumentParser()
	parser.add_argument("src", help="Source file to read from (pass - for stdin)")
	parser.add_argument("--verilog", "-v", help="Verilog file to write to (pass - for stdout)")
	parser.add_argument("--cheader", "-c", help="C header file to write to (pass - for stdout)")
	parser.add_argument("--all", "-a", action="store_true", help="Generate all output types, with default filenames")
	args = parser.parse_args()
	# Default output names are derived from the source filename, which
	# does not exist when reading from stdin.
	if args.src == "-" and args.all:
		sys.exit("Cannot use --all with stdin input")
	if args.src == "-":
		rb = RegBlock.load(sys.stdin)
	else:
		with open(args.src) as sfile:
			rb = RegBlock.load(sfile)

	# For each output type: "-" selects stdout, an explicit name is used
	# as given, and with --all a missing name defaults to the source name
	# with the extension swapped.
	vfile = None
	hfile = None
	if args.verilog is not None or args.all:
		if args.verilog == "-":
			vfile = sys.stdout
		elif args.verilog is None:
			vfile = open(change_ext(args.src, ".v"), "w")
		else:
			vfile = open(args.verilog, "w")
	if args.cheader is not None or args.all:
		if args.cheader == "-":
			hfile = sys.stdout
		elif args.cheader is None:
			hfile = open(change_ext(args.src, ".h"), "w")
		else:
			hfile = open(args.cheader, "w")

	if vfile is not None:
		vw = VerilogWriter()
		rb.accept(vw)
		vfile.write(str(vw.v))
	if hfile is not None:
		hw = HeaderWriter()
		rb.accept(hw)
		hfile.write(str(hw))

	# Close the files opened here; stdout is left alone.
	for out in (vfile, hfile):
		if out is not None and out is not sys.stdout:
			out.close()
